import torch
from peft import peft_model
import geoopt
from tqdm import tqdm

class TSPA:
    """Align several LoRA adapters of a PEFT model with learned rotations.

    For every decoder layer the projection pairs (q_proj, k_proj) and
    (v_proj, o_proj) are processed in two stages:

    1. "Attention alignment": one rotation per adapter is fitted so that the
       product of the two merged projections stays close to the product of
       the base weights; the rotation is then applied to ``lora_A``.
    2. "LoRA alignment": one rotation per adapter is fitted in the rank
       space of each targeted projection so that the product of the
       weighted-sum factors stays close to every adapter's own ``B @ A``.

    All rotations are ``geoopt.ManifoldParameter`` objects on the Stiefel
    manifold, initialised at the identity and optimised with
    ``geoopt.optim.RiemannianAdam``. Adapter weights are modified in place.
    """

    def __init__(
        self,
        # Quoted so the annotation is resolved lazily: the parameter shares
        # its name with the imported ``peft_model`` module.
        peft_model: "peft_model.PeftModel",
        adapter_names: list[str],
        weights: list[float],
        device: str,
    ):
        """Store the model, the adapters to align and their mixing weights.

        Args:
            peft_model: Model holding the adapters; its decoder layers are
                read from ``peft_model.model.model.layers``.
            adapter_names: Names of the adapters to align (at least two are
                required by ``compute_aligned_adapter``).
            weights: One mixing coefficient per adapter, indexed in the same
                order as ``adapter_names``.
            device: Device on which the rotation matrices (and the zero
                placeholders for untargeted projections) are created.
        """
        self.peft_model = peft_model
        self.adapter_names = adapter_names
        self.weights = weights
        self.device = device

    def _new_rotations(self, count, dim):
        """Return ``count`` identity-initialised ``dim x dim`` Stiefel parameters."""
        return [
            geoopt.ManifoldParameter(
                torch.eye(dim, device=self.device),
                manifold=geoopt.Stiefel(),
                requires_grad=True,
            )
            for _ in range(count)
        ]

    def _fit_rotations(self, Rs, loss_fn, learning_rate, training_steps, log_every):
        """Minimise ``loss_fn()`` over ``Rs`` with RiemannianAdam and return ``Rs``."""
        optimizer = geoopt.optim.RiemannianAdam(
            Rs, lr=learning_rate, weight_decay=0.01, amsgrad=True
        )
        for step in tqdm(range(training_steps), desc="Training..."):
            optimizer.zero_grad()
            loss = loss_fn()
            # The loss graph is rebuilt from detached inputs on every step,
            # so nothing has to be retained between iterations.
            loss.backward()
            optimizer.step()

            if step % log_every == 0 or (step + 1) == training_steps:
                print(f"Step {step}: Loss = {loss.item()}")
        return Rs

    def stiefel_optimize_lora(self, N, Bs, As, learning_rate=5e-2, training_steps=2000):
        """Fit one rank-space rotation per adapter for a single projection.

        The loss is the summed squared error between
        ``(sum_i w_i B_i R_i) @ (sum_i w_i A_i R_i).T`` and each adapter's own
        ``B_i @ A_i.T``.

        Args:
            N: Number of adapters; the first ``N`` entries of ``Bs``, ``As``
                and ``self.weights`` are used.
            Bs: Per-adapter B factors; ``Bs[0].shape[1]`` sets the rotation size.
            As: Per-adapter A factors, laid out so that ``Bs[i] @ As[i].T`` is
                defined (``align_multiple_adapters`` passes ``lora_A`` weights
                transposed).
            learning_rate: Learning rate for RiemannianAdam.
            training_steps: Number of optimisation steps.

        Returns:
            List of ``N`` optimised ``geoopt.ManifoldParameter`` rotations.
        """
        # Detach so that backward() only reaches the rotations and never
        # accumulates .grad on the adapter weights themselves.
        Bs = [Bs[i].detach() for i in range(N)]
        As = [As[i].detach() for i in range(N)]
        targets = [B @ A.T for B, A in zip(Bs, As)]
        # Scaling by the mixing weights does not depend on the rotations.
        weighted_Bs = [self.weights[i] * Bs[i] for i in range(N)]
        weighted_As = [self.weights[i] * As[i] for i in range(N)]
        Rs = self._new_rotations(N, Bs[0].shape[1])

        def loss_fn():
            merged_B = sum(wB @ R for wB, R in zip(weighted_Bs, Rs))
            merged_A = sum(wA @ R for wA, R in zip(weighted_As, Rs))
            merged = merged_B @ merged_A.T
            return sum(((merged - target) ** 2).sum() for target in targets)

        return self._fit_rotations(
            Rs, loss_fn, learning_rate, training_steps, log_every=100
        )

    def stiefel_optimize(self, N, Q, K, Q_lora_Bs, Q_lora_As, K_lora_Bs, K_lora_As, learning_rate=5e-2, training_steps=100):
        """Fit one input-space rotation per adapter for a pair of projections.

        With ``Q_ = Q + (sum_i w_i QB_i) @ (sum_i w_i QA_i R_i)`` and ``K_``
        built the same way from the K factors, the loss is
        ``N * ||Q_ @ K_.T - Q @ K.T||^2`` (squared Frobenius norm).

        Args:
            N: Number of adapters.
            Q: Base weight of the first projection.
            K: Base weight of the second projection.
            Q_lora_Bs: Per-adapter B factors of the first projection.
            Q_lora_As: Per-adapter A factors of the first projection;
                ``Q_lora_As[0].shape[1]`` sets the rotation size.
            K_lora_Bs: Per-adapter B factors of the second projection.
            K_lora_As: Per-adapter A factors of the second projection.
            learning_rate: Learning rate for RiemannianAdam.
            training_steps: Number of optimisation steps.

        Returns:
            List of ``N`` optimised ``geoopt.ManifoldParameter`` rotations.
        """
        # Detach every input: only the rotations are being optimised, so no
        # gradient should flow into (or accumulate on) the model weights.
        Q = Q.detach()
        K = K.detach()
        Q_lora_Bs = [Q_lora_Bs[i].detach() for i in range(N)]
        Q_lora_As = [Q_lora_As[i].detach() for i in range(N)]
        K_lora_Bs = [K_lora_Bs[i].detach() for i in range(N)]
        K_lora_As = [K_lora_As[i].detach() for i in range(N)]

        # Rotation-independent quantities are computed once up front.
        base_QK = Q @ K.T
        merged_Q_B = sum(self.weights[i] * Q_lora_Bs[i] for i in range(N))
        merged_K_B = sum(self.weights[i] * K_lora_Bs[i] for i in range(N))

        # The rotation acts on the column (input) dimension of the A factors.
        Rs = self._new_rotations(N, Q_lora_As[0].shape[1])

        def loss_fn():
            merged_Q_A = sum(
                self.weights[i] * (Q_lora_As[i] @ Rs[i]) for i in range(N)
            )
            merged_K_A = sum(
                self.weights[i] * (K_lora_As[i] @ Rs[i]) for i in range(N)
            )
            Q_ = Q + merged_Q_B @ merged_Q_A
            K_ = K + merged_K_B @ merged_K_A
            # The target is the same base product for every adapter, so the
            # per-adapter sum collapses to N times a single term.
            return N * ((Q_ @ K_.T - base_QK) ** 2).sum()

        return self._fit_rotations(
            Rs, loss_fn, learning_rate, training_steps, log_every=10
        )

    def _lora_factors(self, proj, is_target, r):
        """Return per-adapter (Bs, As) for ``proj``; zero placeholders if untargeted."""
        if is_target:
            Bs = [proj.lora_B[name].weight for name in self.adapter_names]
            As = [proj.lora_A[name].weight for name in self.adapter_names]
        else:
            Bs = [
                torch.zeros(proj.weight.shape[0], r).to(self.device)
                for _ in self.adapter_names
            ]
            As = [
                torch.zeros(r, proj.weight.shape[1]).to(self.device)
                for _ in self.adapter_names
            ]
        return Bs, As

    def align_multiple_adapters(self) -> None:
        """Run both alignment stages on every layer and update adapters in place.

        The target modules and the rank ``r`` are taken from the PEFT config
        of the first adapter in ``self.adapter_names``.
        """
        layers = self.peft_model.model.model.layers
        first_config = self.peft_model.peft_config[self.adapter_names[0]]
        target_modules = first_config.target_modules
        r = first_config.r
        N = len(self.adapter_names)

        for layer_idx, layer in enumerate(layers, start=1):
            for proj_name1, proj_name2 in (
                ("q_proj", "k_proj"),
                ("v_proj", "o_proj"),
            ):
                q_proj = getattr(layer.self_attn, proj_name1)
                k_proj = getattr(layer.self_attn, proj_name2)
                q_is_target = proj_name1 in target_modules
                k_is_target = proj_name2 in target_modules
                Q_lora_Bs, Q_lora_As = self._lora_factors(q_proj, q_is_target, r)
                K_lora_Bs, K_lora_As = self._lora_factors(k_proj, k_is_target, r)

                # Attention alignment
                Rs = self.stiefel_optimize(
                    N,
                    q_proj.weight,
                    k_proj.weight,
                    Q_lora_Bs,
                    Q_lora_As,
                    K_lora_Bs,
                    K_lora_As,
                )
                print(f"Aligning {proj_name1} and {proj_name2} in layer {layer_idx}...")
                for adapter_name, R in zip(self.adapter_names, Rs):
                    # Use a plain tensor so the stored weight carries no
                    # autograd history from the optimisation.
                    R = R.detach()
                    if q_is_target:
                        lora_A = q_proj.lora_A[adapter_name].weight
                        lora_A.data = lora_A.data @ R
                    if k_is_target:
                        lora_A = k_proj.lora_A[adapter_name].weight
                        lora_A.data = lora_A.data @ R

                # LoRA alignment
                for proj_name, proj, is_target in (
                    (proj_name1, q_proj, q_is_target),
                    (proj_name2, k_proj, k_is_target),
                ):
                    if not is_target:
                        continue
                    lora_Bs = [proj.lora_B[name].weight for name in self.adapter_names]
                    lora_As = [proj.lora_A[name].weight.T for name in self.adapter_names]

                    print(f"Optimizing LoRA matrices for {proj_name} in layer {layer_idx}...")
                    Rs = self.stiefel_optimize_lora(N, lora_Bs, lora_As)
                    for adapter_name, R in zip(self.adapter_names, Rs):
                        R = R.detach()
                        lora_B = proj.lora_B[adapter_name].weight
                        lora_A = proj.lora_A[adapter_name].weight
                        # B <- B R and A <- R^T A; for an orthogonal R the
                        # product B @ A is unchanged.
                        lora_B.data = lora_B.data @ R
                        lora_A.data = R.T @ lora_A.data
                print()
            print(f"Layer {layer_idx} processed.")
            print("-----------------------------------")

    def compute_aligned_adapter(self) -> None:
        """Validate the configuration and align all adapters in place.

        Raises:
            ValueError: If fewer than two adapters were given, or if there
                are fewer mixing weights than adapters.
        """
        if len(self.adapter_names) < 2:
            raise ValueError("At least two adapters are required for alignment.")
        # Fail fast with a clear message instead of an IndexError raised
        # from inside the optimisation loss.
        if len(self.weights) < len(self.adapter_names):
            raise ValueError(
                f"Expected one weight per adapter: got {len(self.weights)} "
                f"weights for {len(self.adapter_names)} adapters."
            )
        self.align_multiple_adapters()
